feat: Add GLA (Gated Linear Attention) Forward Operator (L2)#2
Open
superAngGao wants to merge 8 commits intomainfrom
Open
feat: Add GLA (Gated Linear Attention) Forward Operator (L2)#2superAngGao wants to merge 8 commits intomainfrom
superAngGao wants to merge 8 commits intomainfrom
Conversation
fc3c7ab to
5f1e0c7
Compare
Implements chunked GLA forward pass with: - Stage 1+2 (PyTorch): within-chunk gate cumsum + inter-chunk hidden state recurrence - Stage 3 (TileLang): intra-chunk causal attention matrix A [B, T, H, BT] - Stage 4 (TileLang): output combining inter-chunk and intra-chunk contributions Files added: - tileops/kernels/gla/gla_fwd.py -- GLAFwdKernel (sm90a) - tileops/kernels/gla/__init__.py - tileops/ops/gla.py -- GLAFwdOp - tests/ops/test_gla.py -- 7 test cases (fp16 + bf16, with/without initial_state) Closes tile-ai#213 Reference: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add seq_len % chunk_size == 0 assertion in GLAFwdOp to prevent OOB writes in TileLang kernels on non-divisible sequence lengths - Cast k/v to float32 per-chunk in GLAFwdKernel.forward to reduce peak memory usage - Fix k_adj formula in ref_gla_fwd to use log-space subtraction (matching GLAFwdKernel) instead of division with clamp - Add test_gla_fwd_non_divisible_seq_len to verify the assertion fires - Add skill.md files for create-new-kernel, create-new-op, create-new-op-test, creating-pull-request, migrating-new-op Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…on skill, add YAML frontmatter and auto-invoke to all skills Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…entions - Single @T.prim_func with 4 @T.macro stages in one T.Serial(num_chunks) loop - Stages run in order 1→3→4→2 so stage4 reads pre-decay h_s before stage2 updates it - Hoist all shared buffers into _main and pass as parameters to eliminate duplicate allocations (stays within 232448 byte optin limit) - Move shape lists inside _gla_fwd_func so outer closure only captures serializable scalars (fixes autotuner assertion) - Add self.kernel assignment in __init__ to support autotune - Fix custom_op namespace to top:: and add autotune_configs - forward() only allocates buffers and calls wrapper; no PyTorch compute Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
5f1e0c7 to
b2783e1
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Implements Gated Linear Attention (GLA) forward pass as a new L2 operator (Kernel + Op).
Closes tile-ai#213
Algorithm
Chunked GLA forward in 4 stages:
h [B,NT,H,K,V]with gated decayA [B,T,H,BT]with gated QKo = scale*(q*exp(g_cs))@h + A@vFiles Changed
tileops/kernels/gla/gla_fwd.pyGLAFwdKernel— TileLang stages 3 & 4, sm90atileops/kernels/gla/__init__.pytileops/ops/gla.pyGLAFwdOp— Op wrappertileops/ops/__init__.pyGLAFwdOptests/ops/test_gla.pyTest Results
Reference
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py
Checklist
__init__.pyexports synchronized